# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Main user-facing classes and utility functions for loading MuJoCo models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import ctypes
import os
import weakref

# Internal dependencies.

from absl import logging

from dm_control.mujoco.wrapper import util
from dm_control.mujoco.wrapper.mjbindings import constants
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper.mjbindings import functions
from dm_control.mujoco.wrapper.mjbindings import mjlib
from dm_control.mujoco.wrapper.mjbindings import types
from dm_control.mujoco.wrapper.mjbindings import wrappers

import six

# Null byte appended to XML file contents placed in a VFS (see
# `_temporary_vfs`).
_NULL = b"\00"
# Placeholder filenames used when a model is loaded from an in-memory string or
# byte string rather than from a file on disk.
_FAKE_XML_FILENAME = b"model.xml"
_FAKE_BINARY_FILENAME = b"model.mjb"

# Global cache used to store finalizers for freeing ctypes pointers.
# Contains {pointer_address: weakref_object} pairs, where `pointer_address` is
# `ctypes.addressof` of the ctypes pointer object (see `_create_finalizer`).
_FINALIZERS = {}

# Cache of ctypes-wrapped Python callback functions that are called from C. We
# need to retain references to all wrapped Python callbacks that are currently
# in use, otherwise they might be garbage collected before they are called.
# Contains {callback_address: wrapped_callback} pairs (see `set_callback`).
_ACTIVE_PYTHON_CALLBACKS = {}


class Error(Exception):
  """Base exception type for errors raised by the MuJoCo wrapper."""


# Refuse to import if the loaded MuJoCo shared library reports a different
# version from the headers the bindings were generated against.
if constants.mjVERSION_HEADER != mjlib.mj_version():
  # The placeholders are labelled {0} = library, {1} = header, so pass the
  # values in that order.
  raise Error("MuJoCo library version ({0}) does not match header version "
              "({1})".format(mjlib.mj_version(), constants.mjVERSION_HEADER))

# Set to True by `_maybe_register_license` once `mj_activate` has succeeded.
_REGISTERED = False
# Size (in bytes) of the buffers that receive error messages from MuJoCo.
_ERROR_BUFSIZE = 1000

# This is used to keep track of the `MJMODEL` pointer that was most recently
# loaded by `_get_model_ptr_from_xml`. Only this model can be saved to XML.
_LAST_PARSED_MODEL_PTR = None

_NOT_LAST_PARSED_ERROR = (
    "Only the model that was most recently loaded from an XML file or string "
    "can be saved to an XML file.")


# NB: Python functions that are called from C are defined at module-level to
# ensure they won't be garbage-collected before they are called.
@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
def _warning_callback(message):
  """Forwards a MuJoCo warning message to the absl warning log."""
  # `logging.warn` is a deprecated alias; use `logging.warning` directly.
  logging.warning(util.to_native_string(message))


@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
def _error_callback(message):
  """Forwards a MuJoCo error message to `logging.fatal`."""
  text = util.to_native_string(message)
  logging.fatal(text)


# Override MuJoCo's callbacks for handling warnings and errors.
# `mju_user_warning` and `mju_user_error` are global function-pointer variables
# exported by the MuJoCo shared library. Writing the addresses of the ctypes
# callbacks `_warning_callback` and `_error_callback` into them redirects
# MuJoCo's warning/error reporting to those Python functions.
mjlib.mju_user_warning = ctypes.c_void_p.in_dll(mjlib, "mju_user_warning")
mjlib.mju_user_error = ctypes.c_void_p.in_dll(mjlib, "mju_user_error")
mjlib.mju_user_warning.value = ctypes.cast(
    _warning_callback, ctypes.c_void_p).value
mjlib.mju_user_error.value = ctypes.cast(
    _error_callback, ctypes.c_void_p).value

# Decorator that wraps a Python callback with the signature
#     func(const_mjmodel_ptr, mjdata_ptr) -> None
# and returns a `ctypes.CFunctionType`. Used by `set_callback` to turn Python
# callables into C function pointers.
_WRAP_PYFUNC = ctypes.CFUNCTYPE(None, ctypes.POINTER(types.MJMODEL),
                                ctypes.POINTER(types.MJDATA))


def _maybe_register_license(path=None):
  """Activates the MuJoCo license key unless it has already been activated.

  Args:
    path: Optional path to the license key file. When None, the path returned
      by `util.get_mjkey_path()` is used.

  Raises:
    Error: If `mj_activate` returns 0 (registration failed) or any code other
      than 1 (unknown error).
  """
  global _REGISTERED
  if _REGISTERED:
    return
  key_path = util.get_mjkey_path() if path is None else path
  status = mjlib.mj_activate(util.to_binary_string(key_path))
  if status == 0:
    raise Error("Could not register license.")
  if status != 1:
    raise Error("Unknown registration error (code: {})".format(status))
  _REGISTERED = True


def _str2type(type_str):
  """Converts an object type name (e.g. 'body') into its `mjtObj` integer ID.

  Raises:
    Error: If `mju_str2Type` returns 0 for `type_str`.
  """
  result = mjlib.mju_str2Type(util.to_binary_string(type_str))
  if result:
    return result
  raise Error("{!r} is not a valid object type name.".format(type_str))


def _type2str(type_id):
  """Converts an `mjtObj` integer ID into its type name, returned as bytes.

  Raises:
    Error: If `mju_type2Str` returns a NULL pointer for `type_id`.
  """
  name_ptr = mjlib.mju_type2Str(type_id)
  if name_ptr:
    return ctypes.string_at(name_ptr)
  raise Error("{!r} is not a valid object type ID.".format(type_id))


def set_callback(name, new_callback=None):
  """Sets a user-defined callback function to modify MuJoCo's behavior.

  Callback functions should have the following signature:
    func(const_mjmodel_ptr, mjdata_ptr) -> None

  Args:
    name: Name of the callback to set. Must be a field in
      `functions.function_pointers`.
    new_callback: The new callback. This can be one of the following:
      * A Python callable
      * A C function exposed by a `ctypes.CDLL` object
      * An integer specifying the address of a callback function
      * None, in which case any existing callback of that name is removed

  Returns:
    Either an integer specifying the address of the previous function used for
    this callback, or None if the callback has not already been overridden.
    If the previous callback was a wrapped Python function, this function
    drops its cached reference to that wrapper, so a caller that later
    re-installs it by address must keep the wrapper alive itself.

  Raises:
    ValueError: If `name` is not in `functions.function_pointers`.
  """
  valid_names = functions.function_pointers._fields
  if name not in valid_names:
    raise ValueError("Invalid callback name: {!r}. Must be one of {!r}.".format(
        name, valid_names))
  target = getattr(functions.function_pointers, name)

  try:
    address = ctypes.cast(new_callback, ctypes.c_void_p)
  except ctypes.ArgumentError:
    # A plain Python callable cannot be cast directly: wrap it as a C function
    # pointer first. C only stores the raw address, so the wrapper is kept in
    # the cache to stop it from being garbage collected while installed.
    c_func = _WRAP_PYFUNC(new_callback)
    address = ctypes.cast(c_func, ctypes.c_void_p)
    _ACTIVE_PYTHON_CALLBACKS[address.value] = c_func

  previous = target.value
  # Release the cached wrapper (if any) of the callback being replaced so it
  # can be garbage collected.
  _ACTIVE_PYTHON_CALLBACKS.pop(previous, None)
  target.value = address.value
  return previous


@contextlib.contextmanager
def callback_context(name, new_callback=None):
  """Context manager that temporarily overrides a MuJoCo callback function.

  On exit, the callback will be restored to its original value (None if the
  callback was not already overridden when the context was entered).

  Args:
    name: Name of the callback to set. Must be a field in
      `functions.function_pointers`.
    new_callback: The new callback. This can be one of the following:
      * A Python callable
      * A C function exposed by a `ctypes.CDLL` object
      * An integer specifying the address of a callback function
      * None, in which case any existing callback of that name is removed

  Yields:
    None

  Raises:
    ValueError: If `name` is not in `functions.function_pointers`.
  """
  # `set_callback` drops its cached reference to the ctypes wrapper of the
  # callback it replaces. If the callback active on entry is a wrapped Python
  # function, hold on to its wrapper here. Otherwise it could be garbage
  # collected while the context is active, and restoring it by address on exit
  # would install a dangling function pointer.
  previous_wrapper = None
  if name in functions.function_pointers._fields:
    current_address = getattr(functions.function_pointers, name).value
    previous_wrapper = _ACTIVE_PYTHON_CALLBACKS.get(current_address)

  old_callback = set_callback(name, new_callback)  # Validates `name`.
  try:
    yield
  finally:
    # Ensure that the callback is reset on exit, even if an exception is raised.
    set_callback(name, old_callback)
    if previous_wrapper is not None:
      # Re-register the restored wrapper so it stays alive while installed and
      # is released normally by a later `set_callback`.
      _ACTIVE_PYTHON_CALLBACKS[old_callback] = previous_wrapper


def get_schema():
  """Returns the schema used by the MuJoCo XML parser.

  The schema is printed by `mj_printSchema` into a 100000-byte buffer, and the
  NUL-terminated contents of that buffer are returned.
  """
  schema_buf = ctypes.create_string_buffer(100000)
  capacity = len(schema_buf)
  mjlib.mj_printSchema(None, schema_buf, capacity, 0, 0)
  return schema_buf.value


@contextlib.contextmanager
def _temporary_vfs(filenames_and_contents):
  """Creates a temporary VFS containing one or more files.

  The VFS is always freed with `mj_deleteVFS` when the context exits. This
  includes the case where populating it fails part-way through (e.g. the VFS is
  full or a filename is duplicated), so files added so far are not leaked.

  Args:
    filenames_and_contents: A dict containing `{filename: contents}` pairs.

  Yields:
    A `types.MJVFS` instance.

  Raises:
    Error: If a file cannot be added to the VFS, or if an error occurs when
      looking up the filename.
  """
  vfs = types.MJVFS()
  mjlib.mj_defaultVFS(vfs)
  try:
    for filename, contents in six.iteritems(filenames_and_contents):
      filename = util.to_binary_string(filename)
      contents = util.to_binary_string(contents)
      _, extension = os.path.splitext(filename)
      # For XML files we need to append a NULL byte, otherwise MuJoCo's parser
      # can sometimes read past the end of the string. However, we should *not*
      # do this for other file types (in particular for STL meshes, where this
      # causes MuJoCo's compiler to complain that the file size is incorrect).
      append_null = extension.lower() == b".xml"
      num_bytes = len(contents) + append_null
      retcode = mjlib.mj_makeEmptyFileVFS(vfs, filename, num_bytes)
      if retcode == 1:
        raise Error("Failed to create {!r}: VFS is full.".format(filename))
      elif retcode == 2:
        raise Error(
            "Failed to create {!r}: duplicate filename.".format(filename))
      file_index = mjlib.mj_findFileVFS(vfs, filename)
      if file_index == -1:
        raise Error("Could not find {!r} in the VFS".format(filename))
      # Copy the contents into the buffer allocated for this file.
      vf = vfs.filedata[file_index]
      vf_as_char_arr = ctypes.cast(
          vf, ctypes.POINTER(ctypes.c_char * num_bytes))
      vf_as_char_arr.contents[:len(contents)] = contents
      if append_null:
        vf_as_char_arr.contents[-1] = _NULL
    yield vfs
  finally:
    mjlib.mj_deleteVFS(vfs)  # Ensure that we free the VFS afterwards.


def _create_finalizer(ptr, free_func):
  """Creates a finalizer for a ctypes pointer.

  Registers a weakref callback on `ptr` that calls `free_func` once `ptr` is
  garbage collected. At most one finalizer is registered per address. If
  `ctypes.addressof(ptr)` is already a key in `_FINALIZERS`, nothing is done.

  Args:
    ptr: A `ctypes.POINTER` to be freed.
    free_func: A callable that frees the pointer. It will be called with a
      pointer of the same type as `ptr`, reconstructed from `ptr`'s address, as
      its only argument when `ptr` is garbage collected.
  """
  ptr_type = type(ptr)
  # Address of the ctypes pointer object's own memory buffer (where the pointer
  # value is stored). It also serves as the key into `_FINALIZERS`.
  address = ctypes.addressof(ptr)

  if address not in _FINALIZERS:  # Only one finalizer needed per address.

    def callback(dead_ptr_ref):
      del dead_ptr_ref  # Unused weakref to the dead ctypes pointer object.
      # Temporarily resurrect the dead pointer so that we can free it.
      # NOTE(review): this re-reads the pointer value from `address` after the
      # original object has died, so it relies on that memory still being
      # intact when the weakref callback runs. Confirm for the target runtime.
      temp_ptr = ptr_type.from_address(address)
      logging.debug("Freeing %s", temp_ptr)
      free_func(temp_ptr)
      del _FINALIZERS[address]  # Remove the weakref from the global cache.

    # Store weakrefs in a global cache so that they don't get garbage collected
    # before their referents.
    _FINALIZERS[address] = weakref.ref(ptr, callback)


def _load_xml(filename, vfs_or_none):
  """Invokes `mj_loadXML` with logging/error handling.

  Args:
    filename: Name of the XML file to load (text or bytes). It is passed to
      `mj_loadXML` together with `vfs_or_none`.
    vfs_or_none: A `types.MJVFS` instance, or None.

  Returns:
    A `ctypes.POINTER` to the loaded `MJMODEL`. A finalizer is registered so
    that `mj_deleteModel` is called when the pointer is garbage collected.

  Raises:
    Error: If `mj_loadXML` returns NULL. The message is MuJoCo's error text.
  """
  error_buf = ctypes.create_string_buffer(_ERROR_BUFSIZE)
  model_ptr = mjlib.mj_loadXML(
      util.to_binary_string(filename),
      vfs_or_none,
      error_buf,
      _ERROR_BUFSIZE)
  if not model_ptr:
    raise Error(util.to_native_string(error_buf.value))
  elif error_buf.value:
    # Loading succeeded but MuJoCo reported a warning in the error buffer.
    # (`logging.warn` is a deprecated alias of `logging.warning`.)
    logging.warning(util.to_native_string(error_buf.value))

  # Free resources when the ctypes pointer is garbage collected.
  _create_finalizer(model_ptr, mjlib.mj_deleteModel)

  return model_ptr


def _get_model_ptr_from_xml(xml_path=None, xml_string=None, assets=None):
  """Parses a model XML file, compiles it, and returns a pointer to an mjModel.

  Args:
    xml_path: Path to a model XML file in MJCF or URDF format.
    xml_string: XML string containing an MJCF or URDF model description.
    assets: Optional dict containing external assets referenced by the model
      (such as additional XML files, textures, meshes etc.), in the form of
      `{filename: contents_string}` pairs. The keys should correspond to the
      filenames specified in the model XML. Ignored if `xml_string` is None.

    One of `xml_path` or `xml_string` must be specified.

  Returns:
    A `ctypes.POINTER` to a new `mjbindings.types.MJMODEL` instance.

  Raises:
    TypeError: If both or neither of `xml_path` and `xml_string` are specified.
    Error: If the model is not created successfully.
  """
  if xml_path is None and xml_string is None:
    raise TypeError(
        "At least one of `xml_path` or `xml_string` must be specified.")
  elif xml_path is not None and xml_string is not None:
    raise TypeError(
        "Only one of `xml_path` or `xml_string` may be specified.")

  _maybe_register_license()

  if xml_string is not None:
    assets = {} if assets is None else assets.copy()
    # Ensure that the fake XML filename doesn't overwrite an existing asset.
    # Asset names may be text or bytes, and `_temporary_vfs` converts them all
    # to bytes. Compare against the converted names and keep `xml_path` as
    # bytes; prepending a text "_" to a bytes name raises TypeError on Python 3.
    existing_names = {util.to_binary_string(name) for name in assets}
    xml_path = _FAKE_XML_FILENAME
    while xml_path in existing_names:
      xml_path = b"_" + xml_path
    assets[xml_path] = xml_string
    with _temporary_vfs(assets) as vfs:
      ptr = _load_xml(xml_path, vfs)
  else:
    ptr = _load_xml(xml_path, None)

  global _LAST_PARSED_MODEL_PTR
  _LAST_PARSED_MODEL_PTR = ptr

  return ptr


def save_last_parsed_model_to_xml(xml_path, check_model=None):
  """Writes a description of the most recently loaded model to an MJCF XML file.

  Args:
    xml_path: Path to the output XML file.
    check_model: Optional `MjModel` instance. If specified, this model will be
      checked to see if it is the most recently parsed one, and a ValueError
      will be raised otherwise.

  Raises:
    Error: If MuJoCo encounters an error while writing the XML file.
    ValueError: If `check_model` was passed, and this model is not the most
      recently parsed one.
  """
  # Compare against None explicitly rather than relying on the truthiness of
  # the model wrapper.
  if check_model is not None and check_model.ptr is not _LAST_PARSED_MODEL_PTR:
    raise ValueError(_NOT_LAST_PARSED_ERROR)
  error_buf = ctypes.create_string_buffer(_ERROR_BUFSIZE)
  mjlib.mj_saveLastXML(util.to_binary_string(xml_path),
                       _LAST_PARSED_MODEL_PTR,
                       error_buf,
                       _ERROR_BUFSIZE)
  if error_buf.value:
    # Decode the message to a native string, as `_load_xml` does for its
    # error buffer.
    raise Error(util.to_native_string(error_buf.value))


def _get_model_ptr_from_binary(binary_path=None, byte_string=None):
  """Returns a pointer to an mjModel from the contents of a MuJoCo model binary.

  Args:
    binary_path: Path to an MJB file (as produced by MjModel.save_binary).
    byte_string: String of bytes (as returned by MjModel.to_bytes).

    One of `binary_path` or `byte_string` must be specified.

  Returns:
    A `ctypes.POINTER` to a new `mjbindings.types.MJMODEL` instance.

  Raises:
    TypeError: If both or neither of `byte_string` and `binary_path`
      are specified.
    Error: If `mj_loadModel` fails to load the model (returns NULL).
  """
  if binary_path is None and byte_string is None:
    raise TypeError(
        "At least one of `byte_string` or `binary_path` must be specified.")
  elif binary_path is not None and byte_string is not None:
    raise TypeError(
        "Only one of `byte_string` or `binary_path` may be specified.")

  _maybe_register_license()

  if byte_string is not None:
    with _temporary_vfs({_FAKE_BINARY_FILENAME: byte_string}) as vfs:
      ptr = mjlib.mj_loadModel(_FAKE_BINARY_FILENAME, vfs)
  else:
    ptr = mjlib.mj_loadModel(util.to_binary_string(binary_path), None)

  # Never hand a NULL model pointer to `MjModel`. Fail loudly here instead, as
  # `_load_xml` does for XML models.
  if not ptr:
    if byte_string is not None:
      raise Error("Failed to load model from byte string.")
    raise Error(
        "Failed to load model from binary file {!r}.".format(binary_path))

  # Free resources when the ctypes pointer is garbage collected.
  _create_finalizer(ptr, mjlib.mj_deleteModel)

  return ptr


# Subclasses implementing constructors/destructors for low-level wrappers.
# ------------------------------------------------------------------------------


class MjModel(wrappers.MjModelWrapper):
  """Wrapper class for a MuJoCo 'mjModel' instance.

  MjModel encapsulates features of the model that are expected to remain
  constant. It also contains simulation and visualization options which may be
  changed occasionally, although this is done explicitly by the user.
  """

  def __init__(self, model_ptr):
    """Creates a new MjModel instance from a ctypes pointer.

    Args:
      model_ptr: A `ctypes.POINTER` to an `mjbindings.types.MJMODEL` instance.
    """
    super(MjModel, self).__init__(model_ptr)

  def __getstate__(self):
    # All of MjModel's state is assumed to reside within the MuJoCo C struct.
    # However there is no mechanism to prevent users from adding arbitrary
    # Python attributes to an MjModel instance - these would not be serialized.
    return self.to_bytes()

  def __setstate__(self, byte_string):
    # Rebuild the underlying mjModel from the bytes produced by `__getstate__`.
    model_ptr = _get_model_ptr_from_binary(byte_string=byte_string)
    self.__init__(model_ptr)

  def __copy__(self):
    # With a None destination, `mj_copyModel` allocates a new mjModel. Like
    # every other model pointer created in this module, it needs a finalizer.
    # Otherwise the copy's memory is never freed.
    new_model_ptr = mjlib.mj_copyModel(None, self.ptr)
    _create_finalizer(new_model_ptr, mjlib.mj_deleteModel)
    return self.__class__(new_model_ptr)

  @classmethod
  def from_xml_string(cls, xml_string, assets=None):
    """Creates an `MjModel` instance from a model description XML string.

    Args:
      xml_string: String containing an MJCF or URDF model description.
      assets: Optional dict containing external assets referenced by the model
        (such as additional XML files, textures, meshes etc.), in the form of
        `{filename: contents_string}` pairs. The keys should correspond to the
        filenames specified in the model XML.

    Returns:
      An `MjModel` instance.
    """
    model_ptr = _get_model_ptr_from_xml(xml_string=xml_string, assets=assets)
    return cls(model_ptr)

  @classmethod
  def from_byte_string(cls, byte_string):
    """Creates an MjModel instance from a model binary as a string of bytes."""
    model_ptr = _get_model_ptr_from_binary(byte_string=byte_string)
    return cls(model_ptr)

  @classmethod
  def from_xml_path(cls, xml_path):
    """Creates an MjModel instance from a path to a model XML file."""
    model_ptr = _get_model_ptr_from_xml(xml_path=xml_path)
    return cls(model_ptr)

  @classmethod
  def from_binary_path(cls, binary_path):
    """Creates an MjModel instance from a path to a compiled model binary."""
    model_ptr = _get_model_ptr_from_binary(binary_path=binary_path)
    return cls(model_ptr)

  def save_binary(self, binary_path):
    """Saves the MjModel instance to a binary file."""
    mjlib.mj_saveModel(self.ptr, util.to_binary_string(binary_path), None, 0)

  def to_bytes(self):
    """Serialize the model to a string of bytes."""
    # `mj_sizeModel` reports the buffer size needed to hold the saved model.
    bufsize = mjlib.mj_sizeModel(self.ptr)
    buf = ctypes.create_string_buffer(bufsize)
    mjlib.mj_saveModel(self.ptr, None, buf, bufsize)
    return buf.raw

  def copy(self):
    """Returns a copy of this MjModel instance."""
    return self.__copy__()

  def name2id(self, name, object_type):
    """Returns the integer ID of a specified MuJoCo object.

    Args:
      name: String specifying the name of the object to query.
      object_type: The type of the object. Can be either a lowercase string
        (e.g. 'body', 'geom') or an `mjtObj` enum value.

    Returns:
      An integer object ID.

    Raises:
      Error: If `object_type` is not a valid MuJoCo object type, or if no object
        with the corresponding name and type was found.
    """
    if not isinstance(object_type, int):
      object_type = _str2type(object_type)
    obj_id = mjlib.mj_name2id(
        self.ptr, object_type, util.to_native_string(name).encode()
        if False else util.to_binary_string(name))
    if obj_id == -1:
      # `_type2str` returns bytes; decode it so the message reads 'body', not
      # b'body'.
      raise Error("Object of type {!r} with name {!r} does not exist.".format(
          util.to_native_string(_type2str(object_type)), name))
    return obj_id

  def id2name(self, object_id, object_type):
    """Returns the name associated with a MuJoCo object ID, if there is one.

    Args:
      object_id: Integer ID.
      object_type: The type of the object. Can be either a lowercase string
        (e.g. 'body', 'geom') or an `mjtObj` enum value.

    Returns:
      A string containing the object name, or an empty string if the object ID
      either doesn't exist or has no name.

    Raises:
      Error: If `object_type` is not a valid MuJoCo object type.
    """
    if not isinstance(object_type, int):
      object_type = _str2type(object_type)
    name_ptr = mjlib.mj_id2name(self.ptr, object_type, object_id)
    if not name_ptr:
      return ""
    return util.to_native_string(ctypes.string_at(name_ptr))

  @contextlib.contextmanager
  def disable(self, *flags):
    """Context manager for temporarily disabling MuJoCo flags.

    Args:
      *flags: Positional arguments specifying flags to disable. Can be either
        lowercase strings (e.g. 'gravity', 'contact') or `mjtDisableBit` enum
        values.

    Yields:
      None

    Raises:
      ValueError: If any item in `flags` is neither a valid name nor a value
        from `enums.mjtDisableBit`.
    """
    old_bitmask = self.opt.disableflags
    new_bitmask = old_bitmask
    for flag in flags:
      if isinstance(flag, six.string_types):
        try:
          field_name = "mjDSBL_" + flag.upper()
          bitmask = getattr(enums.mjtDisableBit, field_name)
        except AttributeError:
          # The last field of `mjtDisableBit` is excluded from the valid names,
          # consistent with the `[:-1]` slice used for integer values below.
          valid_names = [field_name.split("_")[1].lower()
                         for field_name in enums.mjtDisableBit._fields[:-1]]
          raise ValueError("'{}' is not a valid flag name. Valid names: {}"
                           .format(flag, ", ".join(valid_names)))
      else:
        if flag not in enums.mjtDisableBit[:-1]:
          raise ValueError("'{}' is not a value in `enums.mjtDisableBit`. "
                           "Valid values: {}"
                           .format(flag, tuple(enums.mjtDisableBit[:-1])))
        bitmask = flag
      new_bitmask |= bitmask
    self.opt.disableflags = new_bitmask
    try:
      yield
    finally:
      # Restore the original flags even if the body raised.
      self.opt.disableflags = old_bitmask

  @property
  def name(self):
    """Returns the name of the model."""
    # The model name is the first null-terminated string in the `names` buffer.
    return util.to_native_string(
        ctypes.string_at(ctypes.addressof(self.names.contents)))


class MjData(wrappers.MjDataWrapper):
  """Wrapper class for a MuJoCo 'mjData' instance.

  MjData contains all of the dynamic variables and intermediate results produced
  by the simulation. These are expected to change on each simulation timestep.
  """

  def __init__(self, model):
    """Construct a new MjData instance.

    Args:
      model: An MjModel instance.
    """
    # Keep a reference to the parent model. It is exposed via the `model`
    # property and reused when pickling and copying.
    self._model = model

    # Allocate resources for mjData.
    data_ptr = mjlib.mj_makeData(model.ptr)

    # Free resources when the ctypes pointer is garbage collected.
    _create_finalizer(data_ptr, mjlib.mj_deleteData)

    super(MjData, self).__init__(data_ptr, model)

  def __getstate__(self):
    """Returns picklable state: `(model, static_fields, buffer_contents)`."""
    # Note: we can replace this once a `saveData` MJAPI function exists.
    # To reconstruct an MjData instance we need three things:
    #   1. Its parent MjModel instance
    #   2. A subset of its fixed-size fields whose values aren't determined by
    #      the model
    #   3. The contents of its internal buffer (all of its pointer fields point
    #      into this)
    struct_fields = {}
    for name in ["solver", "timer", "warning"]:
      new_structs = []
      for struct in getattr(self, name):
        # Copy each struct into a fresh Python-owned ctypes instance so the
        # saved values don't alias this mjData's memory.
        new_struct = type(struct)()
        ctypes.memmove(ctypes.byref(new_struct), ctypes.byref(struct),
                       ctypes.sizeof(struct))
        new_structs.append(new_struct)
      struct_fields[name] = new_structs
    scalar_field_names = ["ncon", "time", "energy"]
    scalar_fields = {name: getattr(self, name) for name in scalar_field_names}
    static_fields = {"struct_fields": struct_fields,
                     "scalar_fields": scalar_fields}
    # Raw bytes of the internal buffer: `nbuffer` bytes starting at `buffer_`.
    buffer_contents = ctypes.string_at(self.buffer_, self.nbuffer)
    return (self._model, static_fields, buffer_contents)

  def __setstate__(self, state_tuple):
    # Replace this once a `loadData` MJAPI function exists.
    self._model, static_fields, buffer_contents = state_tuple
    # Allocate a fresh mjData for the restored model, then overwrite its
    # contents with the saved state below.
    self.__init__(self.model)
    for name, contents in six.iteritems(static_fields["struct_fields"]):
      target_carray = getattr(self, name)
      for i, struct in enumerate(contents):
        ctypes.memmove(ctypes.byref(target_carray[i]), ctypes.byref(struct),
                       ctypes.sizeof(struct))

    for name, value in six.iteritems(static_fields["scalar_fields"]):
      # Array and scalar values must be handled separately.
      try:
        getattr(self, name)[:] = value
      except TypeError:
        setattr(self, name, value)
    # Copy the saved bytes back into the newly allocated internal buffer.
    buf_ptr = (ctypes.c_char * self.nbuffer).from_address(self.buffer_)
    buf_ptr[:] = buffer_contents

  def __copy__(self):
    # This makes a shallow copy that shares the same parent MjModel instance.
    # A new mjData is allocated, then `mj_copyData` copies this instance's
    # contents into it.
    new_obj = self.__class__(self.model)
    mjlib.mj_copyData(new_obj.ptr, self.model.ptr, self.ptr)
    return new_obj

  def copy(self):
    """Returns a copy of this MjData instance with the same parent MjModel."""
    return self.__copy__()

  @property
  def model(self):
    """The parent MjModel for this MjData instance."""
    return self._model

  @property
  def contact(self):
    """Iterator over detected contacts."""
    # Only the first `ncon` entries of the contact array are yielded, each
    # wrapped in a `wrappers.MjContactWrapper`.
    return (wrappers.MjContactWrapper(ctypes.pointer(c))
            for c in super(MjData, self).contact[:self.ncon])


# Docstrings for these subclasses are inherited from their Wrapper parent class.


class MjvCamera(wrappers.MjvCameraWrapper):

  def __init__(self):
    camera_struct = types.MJVCAMERA()
    camera_ptr = ctypes.pointer(camera_struct)
    # Populate the struct with MuJoCo's default camera settings.
    mjlib.mjv_defaultCamera(camera_ptr)
    super(MjvCamera, self).__init__(camera_ptr)


class MjvOption(wrappers.MjvOptionWrapper):

  def __init__(self):
    option_struct = types.MJVOPTION()
    option_ptr = ctypes.pointer(option_struct)
    # Populate the struct with MuJoCo's default visualization options.
    mjlib.mjv_defaultOption(option_ptr)
    super(MjvOption, self).__init__(option_ptr)


class MjrContext(wrappers.MjrContextWrapper):

  def __init__(self):
    context_struct = types.MJRCONTEXT()
    context_ptr = ctypes.pointer(context_struct)
    # Populate the struct with MuJoCo's default rendering-context values.
    mjlib.mjr_defaultContext(context_ptr)
    super(MjrContext, self).__init__(context_ptr)


class MjvScene(wrappers.MjvSceneWrapper):  # pylint: disable=missing-docstring

  def __init__(self, max_geom=1000):
    """Initializes a new `MjvScene` instance.

    Args:
      max_geom: (optional) An integer specifying the maximum number of geoms
        that can be represented in the scene.
    """
    ptr = ctypes.pointer(types.MJVSCENE())
    # `mjv_makeScene` allocates the scene's resources. `mjv_freeScene` is
    # registered to release them once `ptr` is garbage collected.
    mjlib.mjv_makeScene(ptr, max_geom)
    _create_finalizer(ptr, mjlib.mjv_freeScene)
    super(MjvScene, self).__init__(ptr)


class MjvPerturb(wrappers.MjvPerturbWrapper):

  def __init__(self):
    perturb_struct = types.MJVPERTURB()
    perturb_ptr = ctypes.pointer(perturb_struct)
    # Populate the struct with MuJoCo's default perturbation settings.
    mjlib.mjv_defaultPerturb(perturb_ptr)
    super(MjvPerturb, self).__init__(perturb_ptr)


class MjvFigure(wrappers.MjvFigureWrapper):

  def __init__(self):
    figure_struct = types.MJVFIGURE()
    figure_ptr = ctypes.pointer(figure_struct)
    # Populate the struct with MuJoCo's default figure settings.
    mjlib.mjv_defaultFigure(figure_ptr)
    super(MjvFigure, self).__init__(figure_ptr)
